{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7e1ce83f",
   "metadata": {},
   "source": [
    "# Lesson 5 - Low-Rank Adaptation\n",
    "\n",
    "In this lesson, we're going to explore the idea of serving fine-tuned LLMs trained using Low-Rank Adaptation (LoRA)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f61a7f07",
   "metadata": {},
   "source": [
    "### Import required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c054dd0b-8007-4a7e-85b8-823e343633ff",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79174e42-7b9c-48f0-a104-2fbe7e508d98",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Seed PyTorch's global RNG so the random tensors and layer weights created\n",
    "# below come out the same on every run.\n",
    "# Only torch is seeded; `random` and `numpy` are imported but not seeded here.\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7986fa0",
   "metadata": {},
   "source": [
    "### Create a test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32fde39d-e6b2-4fbf-9a37-7d090f5a30c7",
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "class TestModel(torch.nn.Module):\n",
    "    \"\"\"Tiny embedding -> linear -> lm_head stack standing in for an LLM.\n",
    "\n",
    "    Args:\n",
    "        hidden_size: width of the embedding and of the hidden linear layer.\n",
    "        vocab_size: number of token ids the model embeds and predicts.\n",
    "            Defaults to 10, the size of the toy colour vocabulary.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size, vocab_size=10):\n",
    "        super().__init__()\n",
    "        # keep this creation order: it fixes the order in which the seeded\n",
    "        # RNG initialises the weights\n",
    "        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)\n",
    "        self.linear = torch.nn.Linear(hidden_size, hidden_size)\n",
    "        self.lm_head = torch.nn.Linear(hidden_size, vocab_size)\n",
    "\n",
    "    def forward(self, input_ids):\n",
    "        \"\"\"Map token ids (batch, seq) to logits (batch, seq, vocab_size).\"\"\"\n",
    "        x = self.embedding(input_ids)\n",
    "        x = self.linear(x)\n",
    "        x = self.lm_head(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cada641-571a-432d-b36e-3b1e53b791ed",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Use a reasonably large hidden size so that the LoRA matrices added later\n",
    "# are only a small fraction of the base layer's parameters.\n",
    "hidden_size = 1024\n",
    "model = TestModel(hidden_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f449116-f538-45de-a321-91e821b1cbfc",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# dummy inputs: one sequence of 8 token ids,\n",
    "# shape (batch_size, sequence_length)\n",
    "input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37eba5f4-4571-40d0-805c-08eae9682bc0",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "# toy example of a detokenizer.\n",
    "# The vocabulary only consists of 10 words (different colors).\n",
    "# A token id is simply the index of its word in this list, so the list\n",
    "# needs one entry per output of the model's lm_head (10 by default).\n",
    "detokenizer = [\n",
    "    \"red\",\n",
    "    \"orange\",\n",
    "    \"yellow\",\n",
    "    \"green\",\n",
    "    \"blue\",\n",
    "    \"indigo\",\n",
    "    \"violet\",\n",
    "    \"magenta\",\n",
    "    \"marigold\",\n",
    "    \"chartreuse\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12a198a6",
   "metadata": {},
   "source": [
    "### Reuse the `generate_token` function from Lesson 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44b76c5a-7111-4420-b903-c42dce956a4f",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "# same greedy generation step as in lesson 2 (batching)\n",
    "def generate_token(model, **kwargs):\n",
    "    \"\"\"Return the most likely next word for every sequence in the batch.\n",
    "\n",
    "    The keyword arguments are forwarded to ``model``; the ids it predicts\n",
    "    are looked up in the module-level ``detokenizer`` list.\n",
    "    \"\"\"\n",
    "    # inference only, so skip gradient tracking\n",
    "    with torch.no_grad():\n",
    "        all_logits = model(**kwargs)\n",
    "\n",
    "    # only the final position of each sequence predicts the next token\n",
    "    final_step = all_logits[:, -1, :]\n",
    "    best_ids = final_step.argmax(dim=1)\n",
    "\n",
    "    words = []\n",
    "    for best_id in best_ids:\n",
    "        words.append(detokenizer[best_id])\n",
    "    return words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4356ac2-4af2-481e-b5df-64b3411ed4dd",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# generate one token: generate_token returns one word per batch row,\n",
    "# so take the first (and only) entry\n",
    "next_token = generate_token(model, input_ids=input_ids)[0]\n",
    "next_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e00b14-531f-4ee4-9d2d-194a70df5a8c",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# dummy input tensor\n",
    "# shape: (batch_size, sequence_length, hidden_size)\n",
    "# reuse hidden_size so X always matches the model's linear layer\n",
    "X = torch.randn(1, 8, hidden_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef965870",
   "metadata": {},
   "source": [
    "### Let's set up the LoRA computation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0db46607-6a46-4ef0-add4-1c78f36276e0",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# LoRA A and B tensors\n",
    "# A has shape (hidden_size, rank)\n",
    "# B has shape (rank, hidden_size)\n",
    "# rank is the width of the LoRA bottleneck; it stays far below hidden_size\n",
    "rank = 2\n",
    "lora_a = torch.randn(hidden_size, rank)\n",
    "lora_b = torch.randn(rank, hidden_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0892c319-839c-484d-8a7c-fb10a75815be",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# weight matrix of the base linear layer\n",
    "# (torch.nn.Linear stores it as (out_features, in_features))\n",
    "W = model.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25a3b290-a4e6-4c7d-95fd-5004b1b2cf8b",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# (hidden_size, hidden_size) for this square layer\n",
    "W.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fd62e28-42e7-4f64-9722-4e1e9d37daf0",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# multiply A and B out into the full-size update matrix\n",
    "W2 = lora_a @ lora_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dfe0321-4229-4d47-a86d-4876cd473cf9",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# same shape as W, even though A and B together hold far fewer numbers\n",
    "W2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56f39935-8b11-4e6b-8f57-db98b8b88f12",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "# How many numbers do the two LoRA matrices hold relative to the base weight?\n",
    "# W here has shape (hidden_size, hidden_size)\n",
    "base_numel = W.numel()\n",
    "lora_numel = sum(t.numel() for t in (lora_a, lora_b))\n",
    "print(\"|A+B| / |W|:\", lora_numel / base_numel)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ac80554",
   "metadata": {},
   "source": [
    "### Let's run the LoRA computation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be6b659a-ecb2-492d-a8f9-e7746aa8e471",
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "# compute the output of the original linear layer\n",
    "# (torch.nn.Linear applies X @ W.T + bias)\n",
    "base_output = model.linear(X)\n",
    "\n",
    "# compute the output of X @ A @ B (the added lora adapter)\n",
    "lora_output = X @ lora_a @ lora_b\n",
    "\n",
    "# sum them together\n",
    "total_output = base_output + lora_output\n",
    "\n",
    "# output should have the same shape as the original output:\n",
    "# (batch_size, sequence_length, hidden_size)\n",
    "total_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b301e22-32ac-4096-8d2c-78dda4471e97",
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "class LoraLayer(torch.nn.Module):\n",
    "    \"\"\"Wrap a linear layer and add a rank-``r`` LoRA update to its output.\n",
    "\n",
    "    Computes ``base_layer(x) + x @ lora_a @ lora_b``.\n",
    "\n",
    "    Args:\n",
    "        base_layer: layer to wrap; must expose a ``weight`` of shape\n",
    "            (out_features, in_features), as torch.nn.Linear does.\n",
    "        r: rank of the adapter.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, base_layer, r):\n",
    "        super().__init__()\n",
    "        self.base_layer = base_layer\n",
    "\n",
    "        # torch.nn.Linear stores its weight as (out_features, in_features),\n",
    "        # so the output dimension comes first when unpacking\n",
    "        weight = self.base_layer.weight\n",
    "        d_out, d_in = weight.shape\n",
    "\n",
    "        # register A and B as parameters so they follow the module through\n",
    "        # .to() / state_dict(), and create them with the base layer's dtype\n",
    "        # and device so the matmuls in forward() line up\n",
    "        self.lora_a = torch.nn.Parameter(\n",
    "            torch.randn(d_in, r, dtype=weight.dtype, device=weight.device)\n",
    "        )\n",
    "        self.lora_b = torch.nn.Parameter(\n",
    "            torch.randn(r, d_out, dtype=weight.dtype, device=weight.device)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the base layer's output plus the low-rank correction.\"\"\"\n",
    "        y1 = self.base_layer(x)\n",
    "        y2 = x @ self.lora_a @ self.lora_b\n",
    "        return y1 + y2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a96ee85-de4d-497e-98a1-fb41329ffb1c",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# wrap the linear layer of our toy model, use rank 2\n",
    "# If model.linear has already been replaced by a LoraLayer (this cell was\n",
    "# re-run), wrap its underlying base layer instead of stacking adapters.\n",
    "base_linear = getattr(model.linear, \"base_layer\", model.linear)\n",
    "lora_layer = LoraLayer(base_linear, 2)\n",
    "lora_layer(X).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "519bcdbd",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# swap the wrapped layer into the model, so the model's forward pass now\n",
    "# runs the base linear layer plus the LoRA adapter\n",
    "model.linear = lora_layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec773d5f-2146-4f2e-a724-ef1364f8cf73",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# display the module tree; `linear` is now the LoraLayer that wraps the\n",
    "# original Linear as `base_layer`\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3781f60",
   "metadata": {},
   "source": [
    "### Let's try generating a token after adding the LoRA layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b14005f6-33de-4c13-bfc7-4592282339eb",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# take the first (and only) batch entry right away, so next_token is a\n",
    "# single word here just like in the earlier generation cell\n",
    "next_token = generate_token(model, input_ids=input_ids)[0]\n",
    "next_token"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
